import os
from package.model.XDUnet import XDUnet
from package.util.loadData import data
from package.util.preprocess import preprocessing
from package.util.function  import utils
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from keras.optimizers import Adam

# ----------------------------------------------------------------------------
# Training hyper-parameters
# ----------------------------------------------------------------------------
topchannels = 64           # first positional argument of the network builder
input_width = 160          # input slice width fed to the network and data loader
input_height = 160         # input slice height fed to the network and data loader
input_depth = None         # not referenced elsewhere in this script
n_classes = 7              # number of output classes (also used for label encoding)
batch_size = 8
max_epochs = 600
target = 'Thoracic_OAR'    # name of the checkpoint sub-directory
key = '2DUnet'             # selects the XDUnet builder through model_dict

# Short model key -> name of the XDUnet method that builds that network.
model_dict = dict([
    ('2DUnet', 'Unet2D'),
    ('3DUnet', 'Unet3D'),
])

# Build the network chosen by `key`: model_dict maps the key to the name of an
# XDUnet builder method, which is looked up by name and called.
nets = XDUnet()
model = getattr(nets, model_dict[key])(
    topchannels, n_classes, input_width, input_height
)

# Evaluation metrics are methods of `utils`, also looked up by name.
eva_list = ['global_dice']
metrics = [getattr(utils(), metric_name) for metric_name in eva_list]

model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(lr=3.0e-4),
    metrics=metrics,
)

# Ensure the per-target checkpoint directory exists. Missing parents
# (./build, ./build/checkpoints) are created too; an existing directory is fine.
checkpoint_dir = './build/checkpoints/{}'.format(target)
os.makedirs(checkpoint_dir, exist_ok=True)

# Resume from an existing checkpoint if the directory is not empty.
epoch_begin = 0
if os.listdir(checkpoint_dir):
    # NOTE(review): get_new presumably returns checkpoint file names with the
    # most recent first -- confirm. It is queried once and that single result
    # is used both for loading the weights and for the resume epoch.
    latest_ckpt = utils().get_new(checkpoint_dir)[0]
    model.load_weights(os.path.join(checkpoint_dir, latest_ckpt))
    # ModelCheckpoint writes '<key>-<height>-<width>-<epoch>-<dice>.hdf5',
    # so dash-separated field 3 is the (1-based) epoch number, which is
    # exactly the 0-based `initial_epoch` to continue from.
    epoch_begin = int(latest_ckpt.split('-')[3])

# Dataset selector passed to `data`: 3 -> Lung_GTV, 4 -> Thoracic.
dataset_id = 4

# Load and split the data. 0.1 is presumably the validation fraction -- confirm
# against data.trainTest.
x_train, x_valid, y_train, y_valid = data(
    '{}/dataset'.format(os.getcwd()), dataset_id, input_width, input_height
).trainTest(0.1, n_classes)

generator = preprocessing().generatorGet()

callbacks = [
    # A lower training loss is better, so plateau detection must use
    # mode='min'; mode='max' would treat a steadily decreasing loss as
    # "no improvement" and cut the learning rate while training is improving.
    ReduceLROnPlateau(monitor='loss', factor=0.5, patience=3, mode='min',
                      min_delta=0.005, cooldown=2, verbose=1, min_lr=1e-10),
    # Stop when the validation dice (higher is better) stalls.
    EarlyStopping(monitor='val_global_dice', min_delta=0.001, mode='max',
                  verbose=1, patience=5),
    # Keep only the best checkpoint by validation dice. The file-name layout
    # '<key>-<height>-<width>-<epoch>-<dice>' is parsed when resuming, so keep
    # the field order in sync with that code.
    ModelCheckpoint(filepath='./build/checkpoints/%s/%s-%s-%s-{epoch:03d}-{val_global_dice:05f}.hdf5'%(target, key, input_height,input_width),
                    verbose=True,
                    save_best_only=True,
                    monitor='val_global_dice',
                    mode='max'),
]

model.fit_generator(generator.flow(x_train, y_train, batch_size=batch_size),
                    validation_data=(x_valid, y_valid),
                    # NOTE(review): None relies on the flow iterator being a
                    # keras Sequence (steps = len(iterator)) -- confirm.
                    steps_per_epoch=None,
                    shuffle=True,               # shuffle the training data
                    epochs=max_epochs,
                    # Per the Keras docs this only applies when validation_data
                    # is a generator; with an in-memory tuple it has no effect.
                    validation_steps=100,
                    callbacks=callbacks,
                    initial_epoch=epoch_begin)